import torch
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from os.path import join as pjoin
import torch.nn.functional as F

import torch.optim as optim

import time
import numpy as np
from collections import OrderedDict, defaultdict
# from utils.eval_t2m import evaluation_vqvae, evaluation_res_conv
from utils.utils import print_current_loss
from models_interhuman_selfattn.losses import *

from eval import evaluation_during_training
import os
import sys

from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt
import math
import os
def def_value():
    """Return ``0.0``; serves as the zero-valued default factory for loss accumulators."""
    zero = 0.0
    return zero
def plot_result_figure(path):
    """Plot every scalar series of each TensorBoard event file in ``path``.

    For each directory entry that does not end in ``png`` the entry is loaded
    with ``EventAccumulator``, every scalar tag is drawn into its own subplot
    (three per row), and the grid is saved as
    ``{path}/tensorboard_{file}.png``.

    Failures on a single entry are reported on stdout and do not stop the
    remaining entries from being processed.

    Args:
        path: Directory containing the event files; the PNGs are written
            into the same directory.
    """
    cols = 3  # subplots per row
    for file in os.listdir(path):
        fig = None
        try:
            # Skip the images produced by earlier calls of this function.
            if file.endswith("png"):
                continue
            event_path = path+"/"+file
            ea = event_accumulator.EventAccumulator(event_path)
            ea.Reload()  # actually read the event data from disk

            # Show which kinds of data the file holds (scalars, images, ...).
            print(ea.Tags())

            scalar_tags = ea.Tags().get('scalars', [])
            print("找到的所有标量：", scalar_tags)

            # A grid with zero rows cannot be created; report it explicitly
            # instead of letting plt.subplots raise.
            if not scalar_tags:
                print(f"Skipping {file}: no scalar data to plot")
                continue

            rows = math.ceil(len(scalar_tags) / cols)
            # squeeze=False guarantees a 2-D array of axes for any grid size.
            fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
            axes = axes.flatten()

            # One subplot per scalar tag.
            for ax, tag in zip(axes, scalar_tags):
                events = ea.Scalars(tag)
                steps = [e.step for e in events]
                values = [e.value for e in events]

                ax.plot(steps, values, label=tag)
                ax.set_title(tag)
                ax.set_xlabel('Step')
                ax.set_ylabel('Value')
                ax.grid(True)
                ax.legend()

            # Remove the unused cells of the last row.
            for ax in axes[len(scalar_tags):]:
                fig.delaxes(ax)

            fig.tight_layout()
            fig.savefig(f"{path}/tensorboard_{file}.png")
            plt.show()
        except Exception as e:
            print(f"Error processing {file}: {e}")
        finally:
            # This function runs repeatedly during training; without closing,
            # every call would leave its figure alive in pyplot's registry.
            if fig is not None:
                plt.close(fig)
def save_result_txt(path,output_txt_path,epoch):
    """Append the latest value of every scalar in each event file to a text log.

    For each directory entry in ``path`` that does not end in ``png`` the
    entry is loaded with ``EventAccumulator`` and one line of the form
    ``"{file}_{epoch}: tag value tag value ..."`` is built from the last
    recorded value of every non-empty scalar series (six decimals).

    Failures on a single entry are reported on stdout and that entry is left
    out of the output.

    Args:
        path: Directory containing the event files.
        output_txt_path: Text file the lines are appended to (created if
            missing, even when no line was produced).
        epoch: Value embedded in the line prefix.
    """
    results = []
    for file in os.listdir(path):
        try:
            # PNG files in this directory are not event files.
            if file.endswith("png"):
                continue
            event_path = path+"/"+file
            ea = event_accumulator.EventAccumulator(event_path)
            ea.Reload()  # actually read the event data from disk

            # Show which kinds of data the file holds (scalars, images, ...).
            print(ea.Tags())

            scalar_tags = ea.Tags().get('scalars', [])
            print("找到的所有标量：", scalar_tags)

            exp_name = os.path.basename(file)
            parts = [f"{exp_name}_{epoch}:"]
            for tag in scalar_tags:
                events = ea.Scalars(tag)
                if not events:
                    continue
                # Only the most recent entry of the series is reported.
                parts.append(f" {tag} {events[-1].value:.6f}")
            line = "".join(parts)
            results.append(line)

            print(f"{line}")
        except Exception as e:
            print(f"Error processing {file}: {e}")

    with open(output_txt_path, "a+") as f:
        f.writelines(line + "\n" for line in results)



class RVQTokenizerTrainer:
    """Training driver for ``vq_model``, a model that is called with two motion tensors.

    The trainer owns the optimizer, the LR schedule, checkpointing, TensorBoard
    logging and the per-epoch validation / evaluation passes.

    ``args`` is the option namespace; the attributes read here include
    ``device``, ``accumulation_steps``, ``is_train``, ``log_dir``, the loss
    weights (``commit``, ``loss_explicit``, ``loss_vel``, ``loss_bn``,
    ``loss_geo``, ``loss_fc``, ``lambda_center``, ``lambda_traj``,
    ``Inter_factor``), ``physical_decoder`` and the optimisation settings used
    in :meth:`train`.
    """

    # Order of the values returned by ``forward``; used as TensorBoard tags.
    _TRAIN_TAGS = ('loss', 'Pysical_loss', 'Pose_loss', 'dm_loss', 'ro_loss',
                   'loss_center', 'loss_trajectory', 'pen_loss')
    _VAL_TAGS = ('loss', 'physical_loss', 'pose_loss', 'dm_loss', 'ro_loss',
                 'loss_center', 'loss_trajectory', 'pen_loss')

    def __init__(self, args, vq_model, transformer=None):
        """Store the model and options; build logger and loss modules when training.

        Args:
            args: Option namespace (see class docstring).
            vq_model: The model to train.
            transformer: Optional extra model, kept as ``self.trans`` when given.
        """
        self.opt = args
        self.vq_model = vq_model
        self.device = args.device
        self.accumulation_steps = self.opt.accumulation_steps
        if args.is_train:
            self.logger = SummaryWriter(args.log_dir)
            self.geo_losses = Geometric_Losses(args, args.recons_loss,
                                               self.opt.conv_dim,
                                               self.opt.joints_num,
                                               self.opt.dataset_name,
                                               self.device)

            self.inter_losses = Inter_Losses(args, args.recons_loss,
                                             self.opt.joints_num,
                                             self.opt.dataset_name,
                                             self.device)

        if transformer is not None:
            self.trans = transformer

    @staticmethod
    def _interval(n_batches, divisor):
        """Return ``n_batches // divisor`` clamped to at least 1 (safe as a modulus)."""
        return max(1, n_batches // divisor)

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or NaN when it is empty."""
        return sum(values) / len(values) if values else float('nan')

    def _pose_loss(self, motions1, motions2, pose_motion, commit_loss):
        """Weighted sum of the geometric loss terms, averaged over both motions.

        Returns:
            ``(Pose_loss, loss_center, loss_trajectory)``.
        """
        terms1 = self.geo_losses.forward(motions1, pose_motion[0])
        terms2 = self.geo_losses.forward(motions2, pose_motion[1])
        # Each call yields eight terms; average them pairwise.
        (loss_rec, loss_explicit, loss_vel, loss_bn,
         loss_geo, loss_fc, loss_center, loss_trajectory) = [
            (first + second) / 2 for first, second in zip(terms1, terms2)
        ]
        Pose_loss = loss_rec + (self.opt.commit * commit_loss) + (self.opt.loss_explicit * loss_explicit) + \
            (self.opt.loss_vel * loss_vel) + (self.opt.loss_bn * loss_bn) + (self.opt.loss_geo * loss_geo) + \
            (self.opt.loss_fc * loss_fc) + self.opt.lambda_center * loss_center + self.opt.lambda_traj * loss_trajectory
        return Pose_loss, loss_center, loss_trajectory

    def forward(self, batch_data):
        """Run the model on one batch and compute all loss terms.

        Args:
            batch_data: Indexable batch; items 0 and 1 are the two motion
                tensors (moved to ``self.device`` and cast to float here).

        Returns:
            Tuple ``(loss, Pysical_loss, Pose_loss, dm_loss, ro_loss,
            loss_center, loss_trajectory, pen_loss)`` where
            ``loss = Pose_loss + Inter_factor * Pysical_loss``.
        """
        motions1 = batch_data[0].detach().to(self.device).float()
        motions2 = batch_data[1].detach().to(self.device).float()
        Pose_motion, Pyhsical_motion, loss_commit, perplexity = self.vq_model(motions1, motions2, verbose=True)

        Pose_loss, loss_center, loss_trajectory = self._pose_loss(
            motions1, motions2, Pose_motion, loss_commit[0])

        if not self.opt.physical_decoder:
            # Interaction losses are evaluated on the pose reconstruction.
            dm_loss, ro_loss, pen_loss = self.inter_losses.forward(
                motions1, motions2, Pose_motion[0], Pose_motion[1])
            Pysical_loss = dm_loss + ro_loss + pen_loss
        else:
            # Second output of the model: reconstruction + down-weighted
            # interaction terms + its own commitment loss.
            loss_rec3 = self.geo_losses.forward(motions1, Pyhsical_motion[0])[0]
            loss_rec4 = self.geo_losses.forward(motions2, Pyhsical_motion[1])[0]
            dm_loss, ro_loss, pen_loss = self.inter_losses.forward(
                motions1, motions2, Pyhsical_motion[0], Pyhsical_motion[1])
            inter_loss = dm_loss + ro_loss * 0.01 + pen_loss * 0.1
            Pysical_loss = (loss_rec3 + loss_rec4) / 2 + inter_loss + loss_commit[1] * self.opt.commit

        loss = Pose_loss + self.opt.Inter_factor * Pysical_loss
        return loss, Pysical_loss, Pose_loss, dm_loss, ro_loss, loss_center, loss_trajectory, pen_loss

    def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
        """Set a linearly ramped learning rate on every optimizer param group.

        Args:
            nb_iter: Current iteration count.
            warm_up_iter: Length of the warm-up in iterations.
            lr: Target learning rate.

        Returns:
            The learning rate that was applied,
            ``lr * (nb_iter + 1) / (warm_up_iter + 1)``.
        """
        current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
        for param_group in self.opt_vq_model.param_groups:
            param_group["lr"] = current_lr

        return current_lr

    def save(self, file_name, ep, total_it):
        """Write model, optimizer and scheduler state plus counters to ``file_name``."""
        state = {
            "vq_model": self.vq_model.state_dict(),
            "opt_vq_model": self.opt_vq_model.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            'ep': ep,
            'total_it': total_it,
        }
        torch.save(state, file_name)

    def resume(self, model_dir):
        """Restore state written by :meth:`save`.

        Args:
            model_dir: Path of the checkpoint file.

        Returns:
            ``(epoch, total_iterations)`` stored in the checkpoint.
        """
        checkpoint = torch.load(model_dir, map_location=self.device)
        self.vq_model.load_state_dict(checkpoint['vq_model'])
        self.opt_vq_model.load_state_dict(checkpoint['opt_vq_model'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        return checkpoint['ep'], checkpoint['total_it']

    def train(self, train_loader, val_loader, test_loader, eval_wrapper, plot_eval=None):
        """Train for ``opt.max_epoch`` epochs with per-epoch validation.

        Per epoch: gradient-accumulated training over ``train_loader``, a
        validation pass over ``val_loader`` (best model saved as
        ``finest.tar``), figure export of the TensorBoard logs, and, when
        ``opt.do_eval`` is set, an evaluation on ``test_loader`` that keeps
        ``best_fid.tar`` / ``best_top1.tar``.

        Args:
            train_loader: Iterable of training batches supporting ``len``.
            val_loader: Iterable of validation batches supporting ``len``.
            test_loader: Passed through to ``evaluation_during_training``.
            eval_wrapper: Passed through to ``evaluation_during_training``.
            plot_eval: Unused; kept for interface compatibility.
        """
        self.vq_model.to(self.device)

        iters_per_epoch = len(train_loader)
        total_iters = self.opt.max_epoch * iters_per_epoch
        print(f'\nTotal Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (iters_per_epoch, len(val_loader)))
        self.opt.warm_up_iter = iters_per_epoch // 4
        # Both intervals are used as a modulus below, so they must never be 0
        # (which happened for loaders with fewer than 10 / 2 batches).
        self.opt.log_every = self._interval(iters_per_epoch, 10)
        self.opt.save_latest = self._interval(iters_per_epoch, 2)
        print(f'Warm Up Iters: {self.opt.warm_up_iter}, Log Every: {self.opt.log_every} iters, Save every: {self.opt.save_latest} iters')

        self.opt.milestones = [int(total_iters*0.7), int(total_iters*0.85)]
        print(f"LR milestones: {self.opt.milestones}\n")

        self.opt_vq_model = optim.AdamW(self.vq_model.parameters(), lr=self.opt.lr, betas=(0.9, 0.99), weight_decay=self.opt.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.opt_vq_model, milestones=self.opt.milestones, gamma=self.opt.gamma)

        epoch = 0
        it = 0
        latest_path = pjoin(self.opt.model_dir, 'latest.tar')
        if self.opt.is_continue:
            epoch, it = self.resume(latest_path)
            print("Load model epoch:%d iterations:%d"%(epoch, it))

        start_time = time.time()

        logs = defaultdict(float)

        min_val_loss = np.inf
        min_fid = np.inf
        max_top1 = -np.inf

        if self.opt.do_eval:
            eval_file = pjoin(self.opt.eval_dir, 'evaluation_training.log')

        self.save(latest_path, epoch, it)

        # ``epoch`` is incremented at the top of the body, so a strict
        # comparison is required to run exactly ``max_epoch`` epochs (the
        # count that total_iters and the LR milestones are derived from).
        while epoch < self.opt.max_epoch:
            epoch += 1
            self.vq_model.train()
            for i, batch_data in enumerate(train_loader):
                it += 1
                if it < self.opt.warm_up_iter:
                    self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
                losses = self.forward(batch_data)

                # Record the loss values before the accumulation scaling so
                # the logged total stays comparable to its components and to
                # the validation loss.
                for tag, value in zip(self._TRAIN_TAGS, losses):
                    logs[tag] += value.item()

                if (i % self.accumulation_steps) == 0:
                    self.opt_vq_model.zero_grad()
                # Scale so the accumulated gradient is an average over the
                # accumulation window.
                loss = losses[0] / self.accumulation_steps
                loss.backward()
                if (i + 1) % self.accumulation_steps == 0 or (i + 1) == iters_per_epoch:
                    clip_grad_norm_(self.vq_model.parameters(), max_norm=1.0)
                    self.opt_vq_model.step()
                    # NOTE(review): milestones are expressed in batches, but the
                    # scheduler only advances on optimizer steps; with
                    # accumulation_steps > 1 the decay happens later than the
                    # milestone fractions suggest - confirm this is intended.
                    if it >= self.opt.warm_up_iter:
                        self.scheduler.step()

                logs['lr'] += self.opt_vq_model.param_groups[0]['lr']
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.add_scalar('Train/%s'%tag, mean_loss[tag], it)
                    logs = defaultdict(float)
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    self.save(latest_path, epoch, it)

            self.save(latest_path, epoch, it)
            if epoch==50:
                self.save(pjoin(self.opt.model_dir, 'latest_50.tar'), epoch, it)

            print('Validation time:')
            self.vq_model.eval()
            val_records = {tag: [] for tag in self._VAL_TAGS}
            with torch.no_grad():
                for batch_data in val_loader:
                    for tag, value in zip(self._VAL_TAGS, self.forward(batch_data)):
                        val_records[tag].append(value.item())

            # NaN (empty validation set) compares False against min_val_loss,
            # so no "best" checkpoint is written in that case.
            val_means = {tag: self._mean(values) for tag, values in val_records.items()}
            for tag in self._VAL_TAGS:
                self.logger.add_scalar('Val/%s' % tag, val_means[tag], epoch)
            print('Validation Loss: %.5f physical_loss: %.5f, pose_loss: %.5f, dm_loss: %.5f, ro_loss: %.5f, loss_center: %.5f, loss_trajectory: %.5f, loss_pen: %.5f' %
                  tuple(val_means[tag] for tag in self._VAL_TAGS))

            if val_means['loss'] < min_val_loss:
                min_val_loss = val_means['loss']
                self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
                print('Best Validation Model So Far!~')
            # The plotting helper reads the event files from disk; flush so it
            # sees the scalars written above.
            self.logger.flush()
            plot_result_figure(self.opt.log_dir)
            if self.opt.do_eval:
                self.vq_model.eval()
                fid, mat, top1 = evaluation_during_training(self.opt, self.vq_model, test_loader, eval_wrapper, epoch, eval_file)
                self.logger.add_scalar('Test/FID', fid, epoch)
                self.logger.add_scalar('Test/Matching', mat, epoch)
                self.logger.add_scalar('Test/Top1', top1, epoch)

                if fid < min_fid:
                    min_fid = fid
                    self.save(pjoin(self.opt.model_dir, 'best_fid.tar'), epoch, it)
                    print('Best FID Model So Far!~')
                if top1 > max_top1:
                    max_top1 = top1
                    self.save(pjoin(self.opt.model_dir, 'best_top1.tar'), epoch, it)
                    print('Best Top1 Model So Far!~')
                self.logger.flush()
                plot_result_figure(self.opt.log_dir)
                save_result_txt(self.opt.log_dir, self.opt.save_root+"/log.txt",epoch)
                print('\n')